{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# 6. Azure SQL Memory\n",
    "\n",
    "The memory AzureSQL database can be thought of as a normalized source of truth. The memory module is the primary way pyrit keeps track of requests and responses to targets and scores. Most of this is done automatically. All attacks write to memory for later retrieval. All scorers also write to memory when scoring.\n",
    "\n",
    "The schema is found in `memory_models.py` and can be programmatically viewed as follows\n",
    "\n",
    "## Azure Login\n",
    "\n",
    "PyRIT `AzureSQLMemory` supports only **Azure Entra ID authentication** at this time. User ID/password-based login is not available.\n",
    "\n",
    "Please log in to your Azure account before running this notebook:\n",
    "\n",
    "- Log in with the proper scope to obtain the correct access token:\n",
    "  ```bash\n",
    "  az login --scope https://database.windows.net//.default\n",
    "  ```\n",
    "## Environment Variables\n",
    "\n",
    "Please set the following environment variables to run AzureSQLMemory interactions:\n",
    "\n",
    "- `AZURE_SQL_DB_CONNECTION_STRING` = \"<Azure SQL DB connection string here in SQLAlchemy format>\"\n",
    "- `AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL` = \"<Azure Storage Account results container URL>\" (which uses delegation SAS) but needs login to Azure.\n",
    "\n",
    "To use regular key-based authentication, please also set:\n",
    "\n",
    "- `AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Schema for AttackResultEntries:\n",
      "  Column id (UNIQUEIDENTIFIER)\n",
      "  Column conversation_id (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column objective (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column attack_identifier (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column objective_sha256 (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column last_response_id (UNIQUEIDENTIFIER)\n",
      "  Column last_score_id (UNIQUEIDENTIFIER)\n",
      "  Column executed_turns (INTEGER)\n",
      "  Column execution_time_ms (INTEGER)\n",
      "  Column outcome (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column outcome_reason (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column attack_metadata (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column timestamp (DATETIME)\n",
      "  Column pruned_conversation_ids (VARCHAR(4000) COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column adversarial_chat_conversation_ids (VARCHAR(4000) COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "Schema for PromptMemoryEntries:\n",
      "  Column id (UNIQUEIDENTIFIER)\n",
      "  Column role (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column conversation_id (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column sequence (INTEGER)\n",
      "  Column timestamp (DATETIME)\n",
      "  Column labels (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column prompt_metadata (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column converter_identifiers (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column prompt_target_identifier (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column attack_identifier (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column response_error (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column original_value_data_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column original_value (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column original_value_sha256 (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column converted_value_data_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column converted_value (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column converted_value_sha256 (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column original_prompt_id (UNIQUEIDENTIFIER)\n",
      "Schema for ScoreEntries:\n",
      "  Column id (UNIQUEIDENTIFIER)\n",
      "  Column score_value (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column score_value_description (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column score_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column score_category (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column score_rationale (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column score_metadata (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column scorer_class_identifier (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column prompt_request_response_id (UNIQUEIDENTIFIER)\n",
      "  Column timestamp (DATETIME)\n",
      "  Column task (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column objective (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "Schema for EmbeddingData:\n",
      "  Column id (UNIQUEIDENTIFIER)\n",
      "  Column embedding (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column embedding_type_name (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "Schema for SeedPromptEntries:\n",
      "  Column id (UNIQUEIDENTIFIER)\n",
      "  Column value (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column data_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column name (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column dataset_name (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column harm_categories (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column description (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column authors (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column groups (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column source (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column date_added (DATETIME)\n",
      "  Column added_by (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column prompt_metadata (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column parameters (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
      "  Column prompt_group_id (UNIQUEIDENTIFIER)\n",
      "  Column sequence (INTEGER)\n",
      "  Column value_sha256 (NVARCHAR(4000) COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n"
     ]
    }
   ],
   "source": [
    "from pyrit.memory import CentralMemory\n",
    "from pyrit.setup import AZURE_SQL, initialize_pyrit\n",
    "\n",
    "initialize_pyrit(memory_db_type=AZURE_SQL)\n",
    "\n",
    "memory = CentralMemory.get_memory_instance()\n",
    "memory.print_schema()  # type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "## Basic Azure SQL Memory Programming Usage\n",
    "\n",
    "The `pyrit.memory.azure_sql_memory` module provides functionality to keep track of the conversation history, scoring, data, and more using Azure SQL. You can use memory to read and write data. Here is an example that retrieves a normalized conversation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{}: user: Hi, chat bot! This is my initial prompt.\n",
      "{}: assistant: Nice to meet you! This is my response.\n",
      "{}: user: Wonderful! This is my second prompt to the chat bot!\n"
     ]
    }
   ],
   "source": [
    "from uuid import uuid4\n",
    "\n",
    "from pyrit.models import Message, MessagePiece\n",
    "\n",
    "conversation_id = str(uuid4())\n",
    "\n",
    "message_list = [\n",
    "    MessagePiece(\n",
    "        role=\"user\", original_value=\"Hi, chat bot! This is my initial prompt.\", conversation_id=conversation_id\n",
    "    ),\n",
    "    MessagePiece(\n",
    "        role=\"assistant\", original_value=\"Nice to meet you! This is my response.\", conversation_id=conversation_id\n",
    "    ),\n",
    "    MessagePiece(\n",
    "        role=\"user\",\n",
    "        original_value=\"Wonderful! This is my second prompt to the chat bot!\",\n",
    "        conversation_id=conversation_id,\n",
    "    ),\n",
    "]\n",
    "\n",
    "memory.add_message_to_memory(request=Message([message_list[0]]))\n",
    "memory.add_message_to_memory(request=Message([message_list[1]]))\n",
    "memory.add_message_to_memory(request=Message([message_list[2]]))\n",
    "\n",
    "\n",
    "entries = memory.get_conversation(conversation_id=conversation_id)\n",
    "\n",
    "for entry in entries:\n",
    "    print(entry)\n",
    "\n",
    "\n",
    "# Cleanup memory resources\n",
    "memory.dispose_engine()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
